{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ensembling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightning import pytorch as pl\n",
    "import numpy as np\n",
    "import torch\n",
    "from chemprop import data, models, nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an example [dataloader](./data/dataloaders.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy dataset: the SMILES strings \"C\", \"CC\", \"CCC\" with one random regression target each.\n",
    "smis = [\"C\" * n_carbons for n_carbons in range(1, 4)]\n",
    "ys = np.random.rand(len(smis), 1)\n",
    "datapoints = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]\n",
    "dset = data.MoleculeDataset(datapoints)\n",
    "dataloader = data.build_dataloader(dset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model ensembling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
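    "A single model will sometimes give erroneous predictions for some molecules. These errors can be mitigated by averaging the predictions of several models trained on the same data: each model in the ensemble is initialized independently, so their individual errors differ and tend to partially cancel in the average."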
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_models = 3\n",
    "# Each constructor call builds a fresh model, so the ensemble members are independent\n",
    "# instances that share the same architecture.\n",
    "ensemble = [\n",
    "    models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN())\n",
    "    for _ in range(n_models)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name            | Type               | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing | 227 K  | train\n",
      "1 | agg             | MeanAggregation    | 0      | train\n",
      "2 | bn              | Identity           | 0      | train\n",
      "3 | predictor       | RegressionFFN      | 90.6 K | train\n",
      "4 | X_d_transform   | Identity           | 0      | train\n",
      "5 | metrics         | ModuleList         | 0      | train\n",
      "---------------------------------------------------------------\n",
      "318 K     Trainable params\n",
      "0         Non-trainable params\n",
      "318 K     Total params\n",
      "1.273     Total estimated model params size (MB)\n",
      "24        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 14.38it/s, train_loss_step=0.234, train_loss_epoch=0.234]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=1` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 13.86it/s, train_loss_step=0.234, train_loss_epoch=0.234]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "\n",
      "  | Name            | Type               | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing | 227 K  | train\n",
      "1 | agg             | MeanAggregation    | 0      | train\n",
      "2 | bn              | Identity           | 0      | train\n",
      "3 | predictor       | RegressionFFN      | 90.6 K | train\n",
      "4 | X_d_transform   | Identity           | 0      | train\n",
      "5 | metrics         | ModuleList         | 0      | train\n",
      "---------------------------------------------------------------\n",
      "318 K     Trainable params\n",
      "0         Non-trainable params\n",
      "318 K     Total params\n",
      "1.273     Total estimated model params size (MB)\n",
      "24        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 46.40it/s, train_loss_step=0.215, train_loss_epoch=0.215]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=1` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 23.79it/s, train_loss_step=0.215, train_loss_epoch=0.215]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "\n",
      "  | Name            | Type               | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing | 227 K  | train\n",
      "1 | agg             | MeanAggregation    | 0      | train\n",
      "2 | bn              | Identity           | 0      | train\n",
      "3 | predictor       | RegressionFFN      | 90.6 K | train\n",
      "4 | X_d_transform   | Identity           | 0      | train\n",
      "5 | metrics         | ModuleList         | 0      | train\n",
      "---------------------------------------------------------------\n",
      "318 K     Trainable params\n",
      "0         Non-trainable params\n",
      "318 K     Total params\n",
      "1.273     Total estimated model params size (MB)\n",
      "24        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 42.51it/s, train_loss_step=0.239, train_loss_epoch=0.239]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=1` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 36.88it/s, train_loss_step=0.239, train_loss_epoch=0.239]\n"
     ]
    }
   ],
   "source": [
    "# Fit every ensemble member on the same dataloader, each with its own single-epoch trainer.\n",
    "for model in ensemble:\n",
    "    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)\n",
    "    trainer.fit(model, train_dataloaders=dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 83.86it/s] \n",
      "Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 82.63it/s]\n",
      "Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 68.94it/s] \n"
     ]
    }
   ],
   "source": [
    "# Unshuffled loader so that every model's predictions are in the same molecule order.\n",
    "prediction_dataloader = data.build_dataloader(dset, shuffle=False)\n",
    "# Use a dedicated trainer for inference rather than the `trainer` variable left over from\n",
    "# the last iteration of the training loop, so this cell does not depend on leaked loop state.\n",
    "prediction_trainer = pl.Trainer(logger=False, enable_checkpointing=False)\n",
    "predictions = []\n",
    "for model in ensemble:\n",
    "    # `predict` returns one tensor per batch; join them into a single tensor per model.\n",
    "    predictions.append(torch.concat(prediction_trainer.predict(model, prediction_dataloader)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[0.0096],\n",
       "         [0.0008],\n",
       "         [0.0082]]),\n",
       " tensor([[0.0318],\n",
       "         [0.0260],\n",
       "         [0.0254]]),\n",
       " tensor([[-0.0054],\n",
       "         [ 0.0032],\n",
       "         [-0.0035]])]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One prediction tensor per ensemble member, in the same order as `ensemble`.\n",
    "predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0120],\n",
       "        [0.0100],\n",
       "        [0.0100]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack the per-model tensors along a new leading model axis and average over that axis.\n",
    "# This averages each output column separately, so it stays correct if the models predict\n",
    "# more than one task (concatenating along axis 1 would mix the tasks together).\n",
    "torch.stack(predictions).mean(dim=0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
